{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hard Negatives"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hard negatives are those negative samples that are particularly challenging for the model to distinguish from the positive ones. They are often close to the decision boundary or exhibit features that make them highly similar to the positive samples. Thus hard negative mining is widely used in machine learning tasks to make the model focus on subtle differences between similar instances, leading to better discrimination.\n",
    "\n",
    "In text retrieval system, a hard negative could be document that share some feature similarities with the query but does not truly satisfy the query's intent. During retrieval, those documents could rank higher than the real answers. Thus it's valuable to explicitly train the model on these hard negatives."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, load an embedding model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import FlagModel\n",
    "\n",
    "model = FlagModel('BAAI/bge-base-en-v1.5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, load the queries and corpus from dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "corpus = load_dataset(\"BeIR/scifact\", \"corpus\")[\"corpus\"]\n",
    "queries = load_dataset(\"BeIR/scifact\", \"queries\")[\"queries\"]\n",
    "\n",
    "corpus_ids = corpus.select_columns([\"_id\"])[\"_id\"]\n",
    "corpus = corpus.select_columns([\"text\"])[\"text\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We create a dictionary maping auto generated ids (starting from 0) used by FAISS index, for later use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus_ids_map = {}\n",
    "for i in range(len(corpus)):\n",
    "    corpus_ids_map[i] = corpus_ids[i]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Indexing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the embedding model to encode the queries and corpus:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "pre tokenize: 100%|██████████| 21/21 [00:00<00:00, 46.18it/s]\n",
      "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n",
      "  warnings.warn(\n",
      "Inference Embeddings:   0%|          | 0/21 [00:00<?, ?it/s]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:   5%|▍         | 1/21 [00:49<16:20, 49.00s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  10%|▉         | 2/21 [01:36<15:10, 47.91s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  14%|█▍        | 3/21 [02:16<13:23, 44.66s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  19%|█▉        | 4/21 [02:52<11:39, 41.13s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  24%|██▍       | 5/21 [03:23<09:58, 37.38s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  29%|██▊       | 6/21 [03:55<08:51, 35.44s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  33%|███▎      | 7/21 [04:24<07:47, 33.37s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  38%|███▊      | 8/21 [04:51<06:49, 31.51s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  43%|████▎     | 9/21 [05:16<05:52, 29.37s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  48%|████▊     | 10/21 [05:42<05:13, 28.51s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  52%|█████▏    | 11/21 [06:05<04:25, 26.59s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  57%|█████▋    | 12/21 [06:26<03:43, 24.85s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  62%|██████▏   | 13/21 [06:45<03:06, 23.35s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  67%|██████▋   | 14/21 [07:04<02:33, 21.89s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  71%|███████▏  | 15/21 [07:21<02:03, 20.54s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  76%|███████▌  | 16/21 [07:38<01:36, 19.30s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  81%|████████  | 17/21 [07:52<01:11, 17.87s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  86%|████████▌ | 18/21 [08:06<00:49, 16.58s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  90%|█████████ | 19/21 [08:18<00:30, 15.21s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings:  95%|█████████▌| 20/21 [08:28<00:13, 13.56s/it]Attempting to cast a BatchEncoding to type None. This is not supported.\n",
      "Inference Embeddings: 100%|██████████| 21/21 [08:29<00:00, 24.26s/it]\n"
     ]
    }
   ],
   "source": [
    "p_vecs = model.encode(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5183, 768)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p_vecs.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then create a FAISS index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, faiss\n",
    "import numpy as np\n",
    "\n",
    "# create a basic flat index with dimension match our embedding\n",
    "index = faiss.IndexFlatIP(len(p_vecs[0]))\n",
    "# make sure the embeddings are float32\n",
    "p_vecs = np.asarray(p_vecs, dtype=np.float32)\n",
    "# use gpu to accelerate index searching\n",
    "if torch.cuda.is_available():\n",
    "    co = faiss.GpuMultipleClonerOptions()\n",
    "    co.shard = True\n",
    "    co.useFloat16 = True\n",
    "    index = faiss.index_cpu_to_all_gpus(index, co=co)\n",
    "# add all the embeddings to the index\n",
    "index.add(p_vecs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Searching"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For better demonstration, let's use a single query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'_id': '0',\n",
       " 'title': '',\n",
       " 'text': '0-dimensional biomaterials lack inductive properties.'}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query = queries[0]\n",
    "query"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the id and content of that query, then use our embedding model to get its embedding vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "q_id, q_text = query[\"_id\"], query[\"text\"]\n",
    "# use the encode_queries() function to encode query\n",
    "q_vec = model.encode_queries(queries=q_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the index to search for closest results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['4346436',\n",
       " '17388232',\n",
       " '14103509',\n",
       " '37437064',\n",
       " '29638116',\n",
       " '25435456',\n",
       " '32532238',\n",
       " '31715818',\n",
       " '23763738',\n",
       " '7583104',\n",
       " '21456232',\n",
       " '2121272',\n",
       " '35621259',\n",
       " '58050905',\n",
       " '196664003']"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_, ids = index.search(np.expand_dims(q_vec, axis=0), k=15)\n",
    "# convert the auto ids back to ids in the original dataset\n",
    "converted = [corpus_ids_map[id] for id in ids[0]]\n",
    "converted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'query-id': 0, 'corpus-id': 31715818, 'score': 1}"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "qrels = load_dataset(\"BeIR/scifact-qrels\")[\"train\"]\n",
    "pos_id = qrels[0]\n",
    "pos_id"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lastly, we use the mothod of top-k shifted by N, which get the top 10 negatives after rank 5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['25435456',\n",
       " '32532238',\n",
       " '23763738',\n",
       " '7583104',\n",
       " '21456232',\n",
       " '2121272',\n",
       " '35621259',\n",
       " '58050905',\n",
       " '196664003']"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "negatives = [id for id in converted[5:] if int(id) != pos_id[\"corpus-id\"]]\n",
    "negatives"
   ]
  },
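  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The filter above removes only the single positive stored in `qrels[0]`, but a query may have more than one relevant passage. As a minimal sketch (not part of FlagEmbedding), the same top-k shifted by N step can be wrapped in a helper that accepts a set of positive ids. The name `shifted_top_k_negatives` and its parameters are hypothetical and chosen only for illustration:\n",
    "\n",
    "```python\n",
    "def shifted_top_k_negatives(ranked_ids, positive_ids, shift=5, k=10):\n",
    "    # normalize ids to strings so that '31715818' and 31715818 compare equal\n",
    "    positives = {str(pid) for pid in positive_ids}\n",
    "    # keep the k candidates that follow the first `shift` ranks\n",
    "    window = ranked_ids[shift:shift + k]\n",
    "    # drop every known positive from the window\n",
    "    return [str(doc_id) for doc_id in window if str(doc_id) not in positives]\n",
    "\n",
    "\n",
    "# toy ranking, best match first (made-up ids, not taken from SciFact)\n",
    "ranked = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']\n",
    "print(shifted_top_k_negatives(ranked, positive_ids={'e'}, shift=2, k=4))\n",
    "# ['c', 'd', 'f']\n",
    "```\n",
    "\n",
    "With the variables from this notebook, `shifted_top_k_negatives(converted, {pos_id['corpus-id']})` should give the same list as the cell above."
   ]
  },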
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we have select a group of hard negatives for the first query!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are other methods to refine the process of choosing hard negatives. For example, the [implementation](https://github.com/FlagOpen/FlagEmbedding/blob/master/scripts/hn_mine.py) in our GitHub repo get the top 200 shifted by 10, which mean top 10-210. And then sample 15 from the 200 candidates. The reason is directly choosing the top K may introduce some false negatives, passages that somehow relative to the query but not exactly the answer to that query, into the negative set. This could influence model's performance."
   ]
  }
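  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The range-then-sample idea can be sketched in a few lines. This is not the code from `hn_mine.py`, only an illustration under simple assumptions: the input is a ranked list of ids (best match first), and the function name `sample_negatives` and its parameters are hypothetical.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def sample_negatives(ranked_ids, positive_ids, start=10, end=210, n=15, seed=None):\n",
    "    # candidates are the ranks in [start, end), minus any known positives\n",
    "    positives = {str(pid) for pid in positive_ids}\n",
    "    candidates = [str(d) for d in ranked_ids[start:end] if str(d) not in positives]\n",
    "    # return everything if there are too few candidates to sample from\n",
    "    if len(candidates) <= n:\n",
    "        return candidates\n",
    "    # otherwise sample n candidates without replacement\n",
    "    return random.Random(seed).sample(candidates, n)\n",
    "\n",
    "\n",
    "# toy ranking of 300 made-up ids, best match first\n",
    "ranked = [f'doc{i}' for i in range(300)]\n",
    "negatives = sample_negatives(ranked, positive_ids={'doc12'}, seed=0)\n",
    "print(len(negatives))\n",
    "# 15\n",
    "```\n",
    "\n",
    "Sampling from a wider window trades some hardness for a lower chance of picking false negatives. The `start`, `end`, and `n` defaults here simply mirror the numbers quoted above."
   ]
  }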
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
